import pandas as pd
import numpy as np
from transformers import RobertaTokenizer, RobertaModel
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from pathlib import Path
from tqdm.auto import tqdm
import re
import json
import pickle
import gc
import concurrent.futures
import multiprocessing
import torch
from transformers.tokenization_utils_base import BatchEncoding

import concurrent.futures
import pickle
from tqdm import tqdm

# Directory holding the tokenizer artifacts and the token-encoding map.
files_file_path = "./tokenizers/"

# Encoded-substring -> replacement mapping, consumed by `dekoduj` below.
dictionary = json.loads(Path(files_file_path + "kodovanie.json").read_text(encoding="utf-8"))

def dekoduj(tokens):
    """Turn encoded token strings back into readable text.

    Every (encoded, replacement) pair of the module-level ``dictionary`` is
    applied to each token as a plain substring replacement, in the mapping's
    iteration order, so an earlier replacement can feed a later one.

    Args:
        tokens: iterable of token strings.

    Returns:
        A new list with one decoded string per input token, in input order.
    """
    def _decode_one(text):
        for encoded, replacement in dictionary.items():
            if encoded in text:
                text = text.replace(encoded, replacement)
        return text

    return [_decode_one(token) for token in tokens]


# Collect the corpus shards: every .txt file under the OSCAR subset directory,
# excluding copies inside Jupyter ".ipynb_checkpoints" folders, ordered by the
# integer after the last "_" in the file name (e.g. "part_12.txt" -> 12).
# The result is a list of plain path strings.
paths = [
    str(shard)
    for shard in sorted(
        (
            candidate
            for candidate in Path('../20231210textclean/oscar_sk_subset_20240301').glob('**/*.txt')
            if "ipynb_checkpoints" not in str(candidate)
        ),
        key=lambda shard: int(shard.name.split('_')[-1].split('.')[0]),
    )
]
# For a quick trial run, restrict the input, e.g. `paths = paths[:10]`.

from SKMT_lib.SKMT10 import SKMorfoTokenizer

# The two tokenizers whose encodings are produced side by side: the project's
# SKMT tokenizer and a RoBERTa tokenizer loaded from the "pureBPE" directory.
tokenizer_skmt = SKMorfoTokenizer()
tokenizer_purebpe = RobertaTokenizer.from_pretrained(files_file_path + 'pureBPE')


def mlm(tensor, mask_prob=0.15, mask_token_id=4, max_special_id=2):
    """Randomly overwrite a fraction of non-special token ids with the mask id, in place.

    Each position whose id is greater than ``max_special_id`` is independently
    selected with probability ``mask_prob`` and set to ``mask_token_id``.
    Positions with ids ``<= max_special_id`` are never selected.

    NOTE(review): the defaults presumably correspond to a RoBERTa-style vocab
    (0-2 = <s>/<pad>/</s>, 4 = <mask>); confirm against the trained tokenizers.

    Args:
        tensor: integer tensor of token ids, any shape, on any device.
            Modified in place.
        mask_prob: per-position selection probability.
        mask_token_id: id written at selected positions.
        max_special_id: ids at or below this value are never masked.

    Returns:
        The same tensor object, after masking.
    """
    # Sample on the tensor's own device so the mask can be combined with `tensor > ...`
    # even for GPU tensors (the original always sampled on the CPU).
    rand = torch.rand(tensor.shape, device=tensor.device)
    mask_arr = (rand < mask_prob) & (tensor > max_special_id)
    # Boolean-mask assignment replaces the per-row loop: same result, works for any rank.
    tensor[mask_arr] = mask_token_id
    return tensor


def _mlm_record(encoding):
    """Build one training record (masked inputs, attention mask, unmasked labels) from an encoding."""
    return {
        # mlm() masks in place, so work on a detached copy and keep the original as labels.
        'input_ids': mlm(encoding.input_ids.detach().clone()),
        'attention_mask': encoding.attention_mask,
        # NOTE(review): labels are the raw ids at every position (no -100 on unmasked or
        # padding positions); confirm this is intended for the MLM loss.
        'labels': encoding.input_ids,
    }


def process_file(file_path):
    """Tokenize every non-blank line of a UTF-8 text file with both tokenizers.

    Each stripped, non-empty line is encoded by ``tokenizer_purebpe``
    (padded/truncated to 256 tokens, PyTorch tensors) and by
    ``tokenizer_skmt.tokenize`` with ``max_length=256``. Each encoding is then
    turned into an MLM record via ``mlm``. Blank lines are skipped: they would
    encode to special/padding tokens only, which ``mlm`` never masks.

    Args:
        file_path: path of the text file to process.

    Returns:
        ``(processed_encodings_purebpe, processed_encodings_skmt)``: two
        equally long lists of dicts with ``input_ids``, ``attention_mask`` and
        ``labels`` tensors, one dict per processed line, in file order.
    """
    # Read everything up front so the file is not held open during slow tokenization.
    with open(file_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    processed_encodings_purebpe = []
    processed_encodings_skmt = []

    for raw_line in tqdm(lines, desc=f"Processing {file_path}"):
        line = raw_line.strip()
        if not line:
            continue

        encoding_purebpe = tokenizer_purebpe(line, max_length=256, padding='max_length', truncation=True, return_tensors='pt')
        processed_encodings_purebpe.append(_mlm_record(encoding_purebpe))

        encoding_skmt = tokenizer_skmt.tokenize(line, max_length=256, return_tensors="pt", return_subword=False)
        processed_encodings_skmt.append(_mlm_record(encoding_skmt))

    return processed_encodings_purebpe, processed_encodings_skmt


_RECORD_KEYS = ('input_ids', 'attention_mask', 'labels')


def _stack_records(records):
    """Concatenate per-line record dicts along dim 0; empty input yields 1-D empty long tensors."""
    if not records:
        # torch.cat skips 1-D tensors of size (0,), so these merge cleanly with real data later.
        return {key: torch.empty(0, dtype=torch.long) for key in _RECORD_KEYS}
    return {key: torch.cat([record[key] for record in records]) for key in _RECORD_KEYS}


def process_and_save_file(args):
    """Tokenize one corpus file and stack its per-line records into tensors.

    Despite the name, nothing is written to disk here; the caller merges and
    saves the results.

    Args:
        args: ``(path, i)`` tuple. ``path`` is the file to process; ``i`` (its
            position in the input list) is currently unused.

    Returns:
        ``(encodings_purebpe, encodings_skmt)``: two dicts with
        ``input_ids``, ``attention_mask`` and ``labels`` tensors, each
        concatenated along dim 0 over the file's records. If processing fails,
        or the file yields no records, each tensor is an empty 1-D
        ``torch.long`` tensor. This keeps one bad file from aborting the whole
        pool; previously ``torch.cat([])`` raised right after the error was
        reported.
    """
    path, _i = args

    try:
        records_purebpe, records_skmt = process_file(path)
    except Exception as exc:
        # Best effort: report and skip this file rather than killing the worker pool.
        print(f'Processing of {path} generated an exception: {exc}')
        records_purebpe, records_skmt = [], []

    return _stack_records(records_purebpe), _stack_records(records_skmt)



# Number of worker processes: at most 96, capped at the machine's CPU count so
# smaller hosts are not oversubscribed.
num_processes = min(96, multiprocessing.cpu_count())  # Adjust as needed

# Directory the pickled datasets are written to.
output_dir = Path("tokenizovane_texty")


def _merge_encodings(encoding_chunks, tokenizer_index):
    """Concatenate one tokenizer's per-file tensors into a single dict.

    Args:
        encoding_chunks: list of ``(purebpe, skmt)`` dict pairs as returned by
            ``process_and_save_file``, in input-file order.
        tokenizer_index: 0 selects the pure-BPE dicts, 1 the SKMT dicts.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``, each the
        ``torch.cat`` (dim 0) of that key across all files.
    """
    return {
        key: torch.cat([chunk[tokenizer_index][key] for chunk in encoding_chunks])
        for key in ('input_ids', 'attention_mask', 'labels')
    }


# Guard the driver so start methods that re-import this module in every worker
# (e.g. "spawn") do not recursively launch new pools.
if __name__ == "__main__":
    if not paths:
        raise FileNotFoundError("No input .txt files were found; nothing to tokenize.")

    # Create the output directory up front so a missing folder fails immediately,
    # not after every file has been tokenized.
    output_dir.mkdir(parents=True, exist_ok=True)

    # One task per input file: (path, position in `paths`).
    path_chunks = [(path, i) for i, path in enumerate(paths)]

    # Parallel processing; imap yields results in input order, so rows stay in file order.
    with multiprocessing.Pool(processes=num_processes) as pool:
        encoding_chunks = list(tqdm(pool.imap(process_and_save_file, path_chunks), total=len(path_chunks), desc="Processing files"))

    # Merge and save the pure-BPE dataset.
    merged_encoding = _merge_encodings(encoding_chunks, 0)
    vysledok = BatchEncoding(merged_encoding)
    with open(output_dir / "28032024_purebpe", 'wb') as file:
        pickle.dump(vysledok, file)

    # Merge and save the SKMT dataset.
    merged_encoding_skmt = _merge_encodings(encoding_chunks, 1)
    vysledok_skmt = BatchEncoding(merged_encoding_skmt)
    with open(output_dir / "28032024_skmt", 'wb') as file:
        pickle.dump(vysledok_skmt, file)